from datasets import load_dataset
from data_prep.BaseDatasetProcessor import BaseDatasetProcessor, DEFAULT_PROMPT_TEMPLATE
from functools import partial
import math


def safe_convert_to_float(string_value):
    """Convert a raw score to ``float``, mapping the ``'N/A'`` marker to NaN.

    Any other value is handed to ``float()`` unchanged, so unparsable input
    raises ``ValueError`` / ``TypeError`` exactly as ``float()`` does.
    """
    if string_value == 'N/A':
        return float('nan')
    return float(string_value)


def ultrafeedback_transform_to_preference(batched_sample, dimension, threshold):
    """Turn a batch of raw rows into (lowest-scored, highest-scored) response pairs.

    For each instruction the completions are ranked *numerically* on
    ``dimension``. The lowest-scored completion becomes ``response_0`` and the
    highest-scored one ``response_1``; the pair is emitted only when the two
    scores differ by at least ``threshold``.

    Completions whose score is ``'N/A'`` are ignored while ranking, so one
    unrated completion no longer discards the whole prompt. Prompts with no
    completions, or with no rated completion, produce no row.

    Args:
        batched_sample: Mapping with parallel ``"instruction"`` and
            ``"completions"`` lists. Each completion is a dict providing
            ``"response"`` plus ``"overall_score"`` (for ``"overall"``) or
            ``"annotations"[dimension]["Rating"]`` (fine-grained dimensions).
        dimension: ``"overall"`` or one of ``"instruction_following"``,
            ``"honesty"``, ``"truthfulness"``, ``"helpfulness"``.
        threshold: Minimum score gap (must be positive) for a pair to be kept.

    Returns:
        A dict of four equal-length lists keyed ``"prompt"``, ``"response_0"``,
        ``"response_1"`` and ``f"{dimension}_chosen_id"``. For any other
        ``dimension`` value the lists are empty.
    """
    def chosen_id(score1, score2, threshold):
        """
           Return the index (0 or 1) of the higher of two scores.

           Parameters:
           - score1: The score of response 0.
           - score2: The score of response 1.
           - threshold: The minimum difference for the pair to count.

           Returns:
           - 0 if score1 is higher than score2 by at least the threshold,
           - 1 if score2 is higher than score1 by at least the threshold,
           - -1 if the absolute difference is smaller than the threshold,
           - -9999 if either score is 'N/A' (NaN after conversion).
           """
        assert threshold > 0, "The threshold must be a positive number."
        score1_float = safe_convert_to_float(score1)
        score2_float = safe_convert_to_float(score2)

        # NaN marks a score that was 'N/A'; such a pair cannot be compared.
        if math.isnan(score1_float) or math.isnan(score2_float):
            return -9999  # Sentinel value indicating invalid input

        if abs(score1_float - score2_float) < threshold:
            return -1
        return 0 if score1_float > score2_float else 1

    finegrained_dimensions = ("instruction_following", "honesty", "truthfulness", "helpfulness")
    chosen_column = f"{dimension}_chosen_id"

    new_batched_sample = {
        "prompt": [],
        "response_0": [],
        "response_1": [],
        chosen_column: [],
    }

    # Pick the accessor for the raw score once, outside the per-row loop.
    if dimension == "overall":
        def raw_score(completion):
            return completion["overall_score"]
    elif dimension in finegrained_dimensions:
        def raw_score(completion):
            return completion["annotations"][dimension]["Rating"]
    else:
        # Unknown dimension: emit no rows (only the empty columns).
        return new_batched_sample

    for instruction, completions in zip(batched_sample["instruction"], batched_sample["completions"]):
        if not completions:
            continue

        # Rank on numeric values. Ratings may be strings, and comparing them
        # as text would put 'N/A' above every digit and "9" above "10".
        scored = []
        for completion in completions:
            score = safe_convert_to_float(raw_score(completion))
            if not math.isnan(score):
                scored.append((score, completion))
        if not scored:
            continue

        # min()/max() return the first extreme element, so ties keep the
        # original completion order.
        min_score, min_completion = min(scored, key=lambda pair: pair[0])
        max_score, max_completion = max(scored, key=lambda pair: pair[0])

        val = chosen_id(min_score, max_score, threshold)
        if val < 0:
            continue

        new_batched_sample["prompt"].append(instruction)
        new_batched_sample["response_0"].append(min_completion['response'])
        new_batched_sample["response_1"].append(max_completion['response'])
        new_batched_sample[chosen_column].append(val)

    return new_batched_sample


class UltraFeedbackRDPProcessorBinary(BaseDatasetProcessor):
    """Build one binary (chosen / rejected) preference dataset per scoring dimension.

    The raw dataset is obtained through ``self.get_raw_dataset`` and, for each
    entry of ``dimensions``, reduced to (prompt, chosen, rejected) rows by
    ``ultrafeedback_transform_to_preference`` followed by
    ``_dataset_to_preference_formatter``.
    """
    # Dataset identifier; presumably consumed by the base class when loading
    # the raw data (the loader is not visible in this file).
    dataset_name = "openbmb/UltraFeedback"
    # Scoring dimensions for which a preference dataset can be built.
    dimensions = {"instruction_following", "honesty", "truthfulness", "helpfulness", "overall"}
    # NOTE: the triple-quoted block below is a disabled SCHEMA definition kept
    # as an inert string expression; it has no effect at runtime.
    '''
    SCHEMA = {
        'prompt': 'prompt',
        'chosen': 'response_0',
        'rejected': 'response_1',
    }
    # Dynamically create mappings for each dimension
    SCHEMA.update({
        f'{dimension}_chosen_id': f'chosen_id_dim_{i + 1}'
        for i, dimension in enumerate(dimensions)
    })
    '''
    # Changing the threshold changes how many pairs survive, i.e. the dataset size.
    def __init__(self,
                 prompt_template=DEFAULT_PROMPT_TEMPLATE,
                 num_proc=4, sanity_check=False, threshold=1):
        """Store the pairing threshold and forward the rest to the base class.

        Args:
            prompt_template: Template passed to the base class; it is later
                formatted with a ``raw_prompt`` field.
            num_proc: Passed to the base class; used as ``num_proc`` for the
                ``map`` calls in ``get_preference_dataset``.
            sanity_check: Passed through to the base class unchanged.
            threshold: Minimum score gap between the two responses of a pair.
        """
        super().__init__(num_proc, sanity_check, prompt_template)
        self.threshold = threshold
        print(f"UltraFeedbackRDPProcessor initialized with threshold {self.threshold}")

    def _dataset_to_preference_formatter(self, example, dimension):
        """Map one paired row to ``prompt`` / ``chosen`` / ``rejected``.

        ``example[f"{dimension}_chosen_id"]`` (0 or 1) selects which of
        ``response_0`` / ``response_1`` is the chosen one; the other is rejected.
        """
        chosen_id = example[f"{dimension}_chosen_id"]
        return {
            "prompt":   self.prompt_template.format(raw_prompt=example["prompt"]),
            "chosen":   example[f"response_{chosen_id}"],
            "rejected": example[f"response_{1-chosen_id}"],
        }

    def get_preference_dataset(self, split, seed, removed_dimensions = None):
        """
        Load the raw dataset and build one preference dataset per dimension.

        Parameters:
            split (str): The split of the dataset to load (e.g., 'train', 'test');
                forwarded to ``self.get_raw_dataset``.
            seed: Forwarded to ``self.get_raw_dataset``.
            removed_dimensions: Optional container of dimension names to skip.

        Returns:
            dict: Maps each processed dimension name to a dataset whose rows
            are the ``prompt`` / ``chosen`` / ``rejected`` dicts produced by
            ``_dataset_to_preference_formatter``.
        """
        dataset = self.get_raw_dataset(split, seed)
        # Inspect the columns in the training split
        print("Original columns in the training split:")
        print(dataset)
        original_columns = dataset.column_names

        # A set iterates in hash-seed-dependent order. Sort it so the result
        # dict and the log output come out in the same order on every run.
        # An ordered sequence supplied by a subclass is left untouched.
        dimensions = self.dimensions
        if isinstance(dimensions, (set, frozenset)):
            dimensions = sorted(dimensions)

        dataset_dict = {}
        for dimension in dimensions:
            if removed_dimensions and dimension in removed_dimensions:
                print(f"skip dimension {dimension}")
                continue

            # Step 1: reduce each prompt to a (response_0, response_1) pair
            # plus the id of the preferred response for this dimension.
            to_pairs = partial(
                ultrafeedback_transform_to_preference,
                dimension=dimension,
                threshold=self.threshold,
            )
            newds = dataset.map(
                to_pairs,
                batched=True,
                num_proc=self.num_proc,
                remove_columns=original_columns,
            )
            print("mapping raw dataset to preference...")
            print(newds)

            # Step 2: rename the pair into prompt / chosen / rejected columns.
            to_preference = partial(self._dataset_to_preference_formatter, dimension=dimension)
            dataset_dict[dimension] = newds.map(
                to_preference,
                num_proc=self.num_proc,
                remove_columns=newds.column_names,
            )
        print("Updated columns in the training split:")
        print(dataset_dict)
        return dataset_dict
